
[Feat][Spec Decode] DFlash#36847

Open
benchislett wants to merge 36 commits into vllm-project:main from CentML:dflash-attempt2

Conversation


@benchislett commented Mar 12, 2026

Purpose

Overview

DFlash works much like P-EAGLE (see #32887), but with a major architectural change: it uses bidirectional attention between the query tokens (the last sampled token from the base model plus a block of placeholder mask tokens) and the context states, which are the target model's hidden states from the prefill or accepted tokens.

To implement this, I introduce an extra operation that lives outside of the main model execution and populates the KV cache with the context states directly. Though this operation is not covered by the standard torch.compile and CUDA graph optimizations, handling the context this way allows us to use async scheduling (by writing the full set of context states to the cache and then using seq_lens_gpu to ignore the rejected ones), and enables (piecewise) CUDA graphs for the main forward pass over the query tokens.
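The mechanism above can be sketched roughly as follows (NumPy standing in for CUDA tensors; all names here are illustrative stand-ins, not vLLM's actual API):

```python
import numpy as np

def pre_insert_context(kv_cache, context_states, slot_mapping):
    """Write context key/value states directly into flat KV-cache slots."""
    kv_cache[slot_mapping] = context_states

num_slots, head_dim = 16, 4
kv_cache = np.zeros((num_slots, head_dim))
context_k = np.arange(5 * head_dim, dtype=float).reshape(5, head_dim)
slot_mapping = np.array([3, 4, 5, 6, 7])

# Write the *full* set of context states, including tokens that may
# later be rejected by the verifier.
pre_insert_context(kv_cache, context_k, slot_mapping)

# After verification, only 3 of the 5 context tokens were accepted.
# Attention reads seq_len entries, so the rejected states are ignored
# rather than removed, which is what keeps async scheduling possible.
seq_len = 3
visible = kv_cache[slot_mapping[:seq_len]]
assert np.array_equal(visible, context_k[:seq_len])
```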

(Solved) Because the attention is bidirectional and structured in this way, we cannot allow rejected tokens to stay in the batch as their states would corrupt the rest of the calculation, unlike in standard causal attention where we can simply omit them and sample from an earlier position. Therefore, "disable_padded_drafter_batch" is required and will be enabled when using DFlash, disabling async scheduling as a consequence. At this time I cannot think of a good way around this problem.

Additionally, only a small selection of kernel backends supports non-causal attention, and a smaller subset also supports gpt-oss-style "sinks". Flash Attention with Qwen3-8B is used to test this implementation, since neither Triton Attention nor FlashInfer (TRTLLM) attention supports non-causality. Follow-up work here would be to allow different attention backends for the drafter and the target model; this is out of scope for this PR, but would broaden compatibility for DFlash models.

(Solved) Finally, this new architecture requires extra logic to handle the new input shapes and attention metadata. It is not as simple as P-EAGLE, where we can share code with other modes of speculation, since the sizes and contents of the input tensors are all different. One compatibility component in which I am not absolutely confident is the CUDA graph support: even in piecewise mode, I am not sure how the "padded" query tokens interact with the unpadded context. Is it safe to have operations on the context slice of Q inside the attention op, or must it be moved into a custom op to avoid issues? Enabling torch.compile for the DFlash drafter has not yet been functional, likely for this reason.

Implementation Details

The model is implemented in qwen3_dflash.py and the core speculation logic is in dflash.py. As with draft-model speculative decoding, DFlash has a unique speculation paradigm that requires some refactoring of eagle.py to support cleanly.

Specifically, build_model_inputs_first_pass and build_per_layer_attn_metadata have been introduced in eagle.py so that dflash.py can override them.

Initial implementations inlined the logic for DFlash into eagle.py, but similarly to #24322 the added branching and complexity would (in my opinion) lead to a fairly cluttered EAGLE implementation. In this PR I have tried to separate the concerns to a reasonable degree, so that maintenance of the existing EAGLE pathway is not burdened.
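The override pattern described above can be illustrated with a minimal sketch (class and method bodies here are hypothetical stand-ins, not vLLM's actual signatures):

```python
class EagleProposer:
    def build_model_inputs_first_pass(self, tokens):
        # EAGLE: draft from the causal layout shared with other modes.
        return {"input_ids": list(tokens), "causal": True}

    def propose(self, tokens):
        # Shared driver logic lives here; subclasses only swap the hooks.
        return self.build_model_inputs_first_pass(tokens)

class DFlashProposer(EagleProposer):
    NUM_MASK_TOKENS = 3  # placeholder count, purely illustrative

    def build_model_inputs_first_pass(self, tokens):
        # DFlash: append placeholder mask tokens and drop causality.
        return {
            "input_ids": list(tokens) + ["<mask>"] * self.NUM_MASK_TOKENS,
            "causal": False,
        }

assert EagleProposer().propose(["a"]) == {"input_ids": ["a"], "causal": True}
assert DFlashProposer().propose(["a"])["causal"] is False
```

The point is that the shared driver in eagle.py stays untouched while dflash.py swaps in its own input construction and attention metadata.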

Testing

I added unit tests for both correctness (GSM8k) and acceptance rate for Qwen3-8B with DFlash. I am able to reproduce almost exactly the acceptance-rate values reported in the DFlash paper, and have included this in the test suite with an exact reproduction of their evaluation setup. This should be a valuable resource, and is easily extensible to new DFlash models as the model family continues to evolve and expand.

In local tests on 1xB200, both Qwen3-8B and Qwen3.5-9B pass the test suite using both FA2 and FA4.

Usage

vllm serve Qwen/Qwen3-8B \
    --speculative-config '{"method": "dflash", "model": "z-lab/Qwen3-8B-DFlash-b16", "num_speculative_tokens": 15}' \
    --attention-backend flash_attn \
    --max-num-batched-tokens 32768

Benchmarking

The latest DFlash implementation on this branch is optimized for low-latency performance. I measured the following speedups using the DFlash methodology (https://github.com/z-lab/dflash).

Datasets covered are Alpaca, GSM8k, HumanEval, Math500, MBPP, and MT-Bench. The number of spec tokens for DFlash is 15.

Alpaca

Output tok/s
conc  baseline    dflash  speedup
----  --------  --------  -------
  32  5,268.64  6,210.65    1.179
  16  3,079.46  5,015.65    1.629
   8  1,581.42  3,339.16    2.111
   1    209.46    590.48    2.819

GSM8k

Output tok/s
conc  baseline    dflash  speedup
----  --------  --------  -------
  32  5,261.70  8,433.19    1.603
  16  3,079.63  6,534.04    2.122
   8  1,596.48  4,332.12    2.714
   1    211.35    741.56    3.509

HumanEval

Output tok/s
conc  baseline     dflash  speedup
----  --------  ---------  -------
  32  5,002.62  10,257.62    2.050
  16  3,043.47   8,131.09    2.672
   8  1,583.48   5,433.65    3.431
   1    209.16     969.07    4.633

Math500

Output tok/s
conc  baseline    dflash  speedup
----  --------  --------  -------
  32  5,267.27  9,836.42    1.867
  16  3,087.45  7,668.98    2.484
   8  1,598.57  5,096.36    3.188
   1    211.31    849.22    4.019

MBPP

Output tok/s
conc  baseline    dflash  speedup
----  --------  --------  -------
  32  5,120.32  8,289.34    1.619
  16  3,099.78  6,677.64    2.154
   8  1,584.55  4,487.97    2.832
   1    209.56    810.25    3.866

MT-Bench

Output tok/s
conc  baseline    dflash  speedup
----  --------  --------  -------
  32  5,217.68  6,787.14    1.301
  16  3,082.95  5,297.45    1.718
   8  1,591.69  3,536.00    2.222
   1    210.29    627.01    2.982
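As a sanity check, the speedup column in each table is simply the ratio of the two throughput columns, rounded to three decimals (values below are taken from the tables above):

```python
# (baseline tok/s, dflash tok/s, reported speedup)
rows = [
    (5268.64, 6210.65, 1.179),  # Alpaca,    conc 32
    (209.46,   590.48, 2.819),  # Alpaca,    conc 1
    (211.35,   741.56, 3.509),  # GSM8k,     conc 1
    (209.16,   969.07, 4.633),  # HumanEval, conc 1
]
for baseline, dflash, reported in rows:
    assert round(dflash / baseline, 3) == reported
```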

@mergify

mergify bot commented Mar 12, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @benchislett.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 12, 2026
@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces support for DFlash speculative decoding, a new method that leverages bidirectional attention. The implementation is comprehensive, touching configuration, model definition, the speculative proposer, and the core model runner. Key changes include:

  • A new DFlashProposer and qwen3_dflash model to implement the DFlash architecture.
  • Refactoring of eagle.py to better support different speculative decoding methods.
  • Configuration updates for auto-detection and setup of DFlash.
  • Extensive end-to-end tests for both correctness and acceptance rate, which is great to see.

The code is well-structured, and the refactoring improves extensibility. My main concern, which you've also highlighted in the PR description and code comments, is the compatibility with CUDA graphs due to the torch.cat operation on potentially differently padded tensors. This is a critical point for performance and stability, and I've left a comment with a suggestion on how to address it.

@benchislett
Collaborator Author

Update on the graphs: I have a local workaround that I'm working on cleaning up. The solution is to put the context states in the forward_context and access them via a CustomOp, similar to unified_attention_with_output. Then all the ordinary logic can go in the main graph.
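A minimal sketch of that idea (names are illustrative; the real implementation goes through vLLM's forward_context and CustomOp machinery):

```python
_forward_context = {}

def set_forward_context(**kwargs):
    # Populated outside the compiled region, once per step.
    _forward_context.clear()
    _forward_context.update(kwargs)

def context_concat_op(query_states):
    # Inside the graph: fetch the context through a side channel instead
    # of passing it as a traced input with a varying shape.
    context_states = _forward_context["context_states"]
    return context_states + query_states  # list concat stands in for torch.cat

set_forward_context(context_states=[10, 11])
assert context_concat_op([12]) == [10, 11, 12]
```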

@mergify

mergify bot commented Mar 13, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @benchislett.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 13, 2026
@jianc99

jianc99 commented Mar 18, 2026

Thanks @benchislett for implementing DFlash in vLLM. We’ve just released DFlash checkpoints for Qwen3.5-4B, 9B, 27B, and 35B-A3B. They perform very well and are faster than MTP. Feel free to try them out. We’ll continue adding support for more models, and I’m really looking forward to seeing DFlash run in vLLM!

benchislett and others added 14 commits March 19, 2026 15:58
@mergify

mergify bot commented Mar 19, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @benchislett.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 19, 2026
@benchislett
Collaborator Author

benchislett commented Mar 19, 2026

(force-pushed to fix DCO. manually checked all test cases still pass)

@mergify mergify bot removed the needs-rebase label Mar 19, 2026
@benchislett
Collaborator Author

benchislett commented Mar 19, 2026

@mgoin

Regarding this issue:
assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
It is a side-effect of how we currently handle parallel drafting and max_num_batched_tokens. It can be resolved by increasing max_num_batched_tokens to 32768.

It will now raise a clear error on startup as of e99905a.
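A sketch of what such a startup check might look like (the formula and function name are illustrative; the real constraint lives in vLLM's scheduler configuration):

```python
def validate_spec_budget(max_num_seqs, num_spec_tokens, max_num_batched_tokens):
    # Worst case: every sequence schedules its bonus token plus all drafts.
    required = max_num_seqs * (num_spec_tokens + 1)
    if max_num_batched_tokens < required:
        raise ValueError(
            f"max_num_batched_tokens={max_num_batched_tokens} is too small; "
            f"need at least {required} for {max_num_seqs} seqs with "
            f"{num_spec_tokens} speculative tokens"
        )

validate_spec_budget(256, 15, 32768)      # 256 * 16 = 4096: fits
try:
    validate_spec_budget(2048, 15, 8192)  # 2048 * 16 = 32768 > 8192
except ValueError as e:
    assert "too small" in str(e)
```

A clear error at startup beats a mid-run scheduler assert.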

gpu_memory_utilization=0.85,
enforce_eager=False,
disable_log_stats=False,
attention_config={"backend": "FLASH_ATTN"}, # Required for non-causal attention
Member

Can we get this working across the board without needing to specify this arg? We should be able to resolve this internally by querying the attention backend's supports_attn_type during selection.

Comment on lines 856 to +860
def use_eagle(self) -> bool:
return self.method in ("eagle", "eagle3", "mtp")
return self.method in ("eagle", "eagle3", "mtp", "dflash")

def use_dflash(self) -> bool:
return self.method == "dflash"
Member

nit: what is the point of use_eagle being used for several methods? maybe we should rename this to a more general pattern like uses_hidden_state_proposer

Comment on lines +370 to +374
normed_context_states = rmsnorm(
input=context_states,
weight=self._hidden_norm_weight,
eps=self._rms_norm_eps,
)
Member

What is the purpose of using the flashinfer rmsnorm? I'd prefer to have this use the general rmsnorm op in vLLM, and the flashinfer version should be a backend within it.

Collaborator Author

Agree, most of this is due to my own ignorance about vLLM's kernel internals. It should be feasible to dispatch to vLLM's fused RMSNorm here. For the RoPE though, I'm not sure if we have a native fused kernel. I can look into it.

Collaborator Author

Definitely don't need the dependency on FlashInfer long-term. But was very easy in the prototype. I'll clean this up

Member

I think it's certainly okay to land as-is as long as we remove the global flashinfer imports

Comment on lines +403 to +412
# In-place RoPE cannot be called here, since we use K for both query and key.
# Instead we just call the fused kernel and ignore the query output.
_, all_k_flat = apply_rope_with_cos_sin_cache(
positions=positions_repeated,
query=all_k_flat,
key=all_k_flat,
head_size=self._rope_head_size,
cos_sin_cache=self._rope_cos_sin_cache,
is_neox=self._rope_is_neox,
)
Member

Ditto here, though I understand from the comment that this one may be more necessary. I think other than these two flashinfer kernels, everything else should not necessarily have a dependency on NVIDIA GPUs. If this must be kept, we should gate this path on a CUDA platform check.

@mgoin (Member) left a comment

Really really nice work! I think these are all the things I found for now, but I should take another look through soon

Comment on lines +8 to +9
from flashinfer import rmsnorm
from flashinfer.rope import apply_rope_with_cos_sin_cache
Member

Ditto on CUDA platform check/lazy import when needed instead of unconditionally putting flashinfer import at the top of a model file
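One way to avoid the unconditional top-level import is a lazy resolver, sketched here with a stdlib module standing in for flashinfer (helper name is illustrative):

```python
import importlib

def lazy_kernel(module_name, attr):
    """Return a resolver that imports the kernel only on first use."""
    def resolve():
        try:
            return getattr(importlib.import_module(module_name), attr)
        except ImportError:
            return None  # e.g. flashinfer absent on non-CUDA platforms
    return resolve

# Stand-in: math.sqrt resolved lazily; flashinfer.rmsnorm would be
# resolved the same way, behind a CUDA platform check.
get_sqrt = lazy_kernel("math", "sqrt")
assert get_sqrt()(9.0) == 3.0
assert lazy_kernel("not_a_real_flashinfer_pkg", "rmsnorm")() is None
```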

num_layers = len(parent_ref.model.layers)
return (2, num_layers // 2, num_layers - 3)

def _get_default_eagle3_aux_hidden_state_layers(self) -> tuple[int]:
Member

nit: should be tuple[int, ...] since i think it can be variable in length

Comment on lines +141 to +152
# Save context slot_mapping for pre-insertion (clone because buffer
# will be reused by _get_slot_mapping)
self._dflash_context_slot_mapping = self._slot_mapping_buffer[
:num_context
].clone()

# Only query slot_mapping for the model forward pass — context KVs
# are pre-inserted into cache before the forward. Clone to avoid
# aliasing with the buffer that _get_slot_mapping writes into.
query_slot_mapping = self._slot_mapping_buffer[
num_context:num_all_positions
].clone()
Member

If these clones are a perf issue, I think it would be straightforward to double-buffer
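A sketch of the double-buffering idea, with NumPy arrays standing in for the GPU buffers (names are illustrative):

```python
import numpy as np

class SlotMappingBuffers:
    """Two preallocated buffers, alternated each scheduling step."""

    def __init__(self, capacity):
        self._bufs = [np.empty(capacity, dtype=np.int64) for _ in range(2)]
        self._idx = 0

    def next_buffer(self):
        # Flip; the previously returned buffer's contents stay valid for
        # one more step, so no defensive clone is needed.
        self._idx ^= 1
        return self._bufs[self._idx]

bufs = SlotMappingBuffers(8)
context_view = bufs.next_buffer()
context_view[:4] = [10, 11, 12, 13]   # step N slot mapping
query_view = bufs.next_buffer()
query_view[:4] = [20, 21, 22, 23]     # step N+1 writes to the other buffer
assert list(context_view[:4]) == [10, 11, 12, 13]  # step N view not clobbered
```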

Collaborator Author

good catch

tl.store(out_positions_ptr + pos_out_idx, positions, mask=in_bounds)

# --- Slot mapping (block_table lookup for all positions) ---
block_num = positions // block_size
Member

I believe this needs the same clamping used in the eagle kernel above:

    # Block table lookup: block_number = position // block_size
    # Clamp block_number to avoid OOB when position is at max
    block_number = clamped_position // block_size
    block_number = tl.minimum(block_number, n_blocks_per_req - 1)
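A plain-Python restatement of the suggested clamp, showing why the boundary position needs it:

```python
def lookup_block(position, block_size, n_blocks_per_req):
    block_number = position // block_size
    # Clamp so the boundary position does not index past the block table.
    return min(block_number, n_blocks_per_req - 1)

block_size, n_blocks = 16, 4            # capacity: positions 0..63
assert lookup_block(63, block_size, n_blocks) == 3  # last valid position
assert lookup_block(64, block_size, n_blocks) == 3  # boundary: clamped, not OOB
```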

from vllm.model_executor.models.deepseek_eagle3 import Eagle3DeepseekV2ForCausalLM
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.model_executor.models.qwen3_dflash import DFlashQwen3ForCausalLM
Member

Yeah we should remove the top-level flashinfer import if we need to import the model class like this

@mergify

mergify bot commented Mar 23, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @benchislett.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 23, 2026
@geraldstanje1

Hi @benchislett, do you have an image of this branch for testing? I would like to test it with gpt-oss-20b on an RTX 6000 Pro.
